"""
test_merged.py
--------------
测试合并后的模型是否正常工作
每行都有注释，直接改值即可
"""

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Directory containing the merged model
MODEL_PATH = r"./merged-model"

# Questions to ask the model; you can list several at once
QUESTIONS = [
    "你是谁？",
    "你能做什么？",
    "给我讲个笑话"
]

# Load the tokenizer (use_fast=False selects the slow Python tokenizer)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)

# Load the merged model in fp16, let `device_map="auto"` place it on the
# available device(s), and switch straight to inference mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()

# Generate and print an answer for each question.
for q in QUESTIONS:
    # Build the prompt in the "用户：…\nAI：" chat format the model expects.
    prompt = f"用户：{q}\nAI："
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=100,
            # Explicit pad token avoids the "pad_token_id not set" warning on
            # models that define only an EOS token.
            pad_token_id=tokenizer.eos_token_id,
        )

    # BUG FIX: out[0] contains prompt tokens followed by the generated tokens,
    # so decoding it whole echoes the prompt back in the answer. Slice off the
    # prompt before decoding so only the model's reply is printed.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    print(f"问题：{q}")
    print(f"回答：{answer}\n")
